from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from matplotlib import pyplot as plt


# Headline finding; appended as a second title line on the single-panel scatter figure.
TAKE_HOME_MESSAGE = " ".join(
    [
        "More media/information freedom -> greater transparency during COVID",
        "(smaller excess-vs-reported mortality gaps).",
    ]
)


@dataclass(frozen=True)
class Window:
    """A named, immutable span of calendar months.

    The month bounds are plain ``"YYYY-MM"`` strings; ``name`` is the label
    written into the ``window`` column of aggregated output.
    """

    name: str  # display label for the span, e.g. "2020-2022"
    start: str  # first month of the span, "YYYY-MM"
    end: str  # last month of the span, "YYYY-MM"


def month_to_ts(month_id: str) -> pd.Timestamp:
    """Return the timestamp for the first instant of the month ``month_id`` ("YYYY-MM")."""
    as_period = pd.Period(month_id, freq="M")
    return as_period.to_timestamp(how="start")


def aggregate_country_window(
    panel: pd.DataFrame,
    window: Window,
    outcome: str,
    min_months_nonmissing: int = 12,
) -> pd.DataFrame:
    """Collapse a country-month panel into one row per country for ``window``.

    Rows are kept when their month lies between ``window.start`` and
    ``window.end`` (both inclusive, compared at calendar-month resolution) and
    ``media_freedom`` is non-missing. Countries with fewer than
    ``min_months_nonmissing`` non-missing ``outcome`` months in that span are
    dropped.

    Args:
        panel: Country-month rows; must contain ``iso_code``, ``month``,
            ``media_freedom`` and ``outcome``. ``location``, ``continent`` and
            the control columns listed below are carried through when present.
        window: Span to aggregate over; ``window.name`` is written to the
            output's ``window`` column.
        outcome: Column summed over the window into ``y``.
        min_months_nonmissing: Minimum number of non-missing ``outcome``
            months a country needs in the window to be kept.

    Returns:
        One row per ``iso_code`` with ``y`` (sum of ``outcome``; NaN when the
        country has no non-missing months), ``media_freedom``, descriptors and
        controls (each collapsed with ``max`` -- presumably constant within a
        country; TODO confirm), ``log_gdp_pc`` when ``gdp_per_capita`` exists
        (NaN for non-positive values), and ``window``.
    """
    p = panel.copy()
    p["month"] = pd.to_datetime(p["month"])
    month_period = p["month"].dt.to_period("M")
    p["month_id"] = month_period.astype(str)

    # Compare monthly periods, not timestamps: a timestamp bound at the first
    # day of the end month would drop end-month rows dated after the 1st
    # (e.g. month-end dates).
    in_window = (month_period >= pd.Period(window.start, freq="M")) & (
        month_period <= pd.Period(window.end, freq="M")
    )
    w = p[in_window & p["media_freedom"].notna()].copy()

    # Per-country count of non-missing outcome months (count() skips NaN).
    w["n_months_nonmissing"] = w.groupby("iso_code")[outcome].transform("count")
    w = w[w["n_months_nonmissing"] >= min_months_nonmissing]

    controls = [
        "population",
        "gdp_per_capita",
        "median_age",
        "hospital_beds_per_thousand",
        "diabetes_prevalence",
        "human_development_index",
        "death_reg_cod_pct",
        "wgi_gov_effectiveness",
    ]
    controls = [c for c in controls if c in w.columns]
    descriptors = [c for c in ("location", "continent") if c in w.columns]

    named_aggs = {
        # min_count=1: a country whose outcome is entirely missing gets NaN,
        # not a spurious 0 (reachable when min_months_nonmissing <= 0).
        "y": (outcome, lambda s: s.sum(min_count=1)),
        "media_freedom": ("media_freedom", "max"),
        **{c: (c, "max") for c in descriptors},
        **{c: (c, "max") for c in controls},
    }
    agg = w.groupby(["iso_code"], as_index=False).agg(**named_aggs)

    if "gdp_per_capita" in agg.columns:
        gdp = agg["gdp_per_capita"]
        # Non-positive GDP has no log; emit NaN instead of -inf.
        agg["log_gdp_pc"] = np.log(gdp.where(gdp > 0))
    agg["window"] = window.name
    return agg


def main_scatter(
    agg: pd.DataFrame,
    out_path: Path,
    y_label: str,
    title: str,
) -> None:
    """Save a scatter of ``y`` against ``media_freedom`` with a fitted line.

    Rows missing either column are dropped before plotting. The figure title
    is ``title`` followed by ``TAKE_HOME_MESSAGE`` on a second line, and the
    PNG is written to ``out_path`` (parent directories are created).
    """
    plot_df = agg.dropna(subset=["y", "media_freedom"])

    fig, ax = plt.subplots(figsize=(9.5, 6.5))
    sns.regplot(
        data=plot_df,
        x="media_freedom",
        y="y",
        scatter_kws={"alpha": 0.45, "s": 22},
        line_kws={"color": "black", "linewidth": 2},
        ci=95,
        ax=ax,
    )
    ax.set_title(f"{title}\n{TAKE_HOME_MESSAGE}")
    ax.set_xlabel("Media freedom (V-Dem v2x_freexp_altinf, baseline <=2019)")
    ax.set_ylabel(y_label)
    fig.tight_layout()

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=250)
    plt.close(fig)


def _threshold_ticks(thresholds: pd.Series) -> tuple[list, list[str]]:
    """Return the sorted distinct thresholds and matching ">=k" tick labels."""
    positions = sorted(thresholds.dropna().unique().tolist())
    labels = [f">={t:g}" for t in positions]
    return positions, labels


def thresholds_plot(thresholds_csv: Path, out_path: Path) -> None:
    """Plot the media-freedom coefficient against the 2023 coverage threshold.

    Reads ``thresholds_csv`` (columns ``threshold_months_nonmissing_2023``,
    ``coef``, ``se``) and saves a point-and-line chart with +/-1.96*se error
    bars to ``out_path`` (parent directories are created).
    """
    d = pd.read_csv(thresholds_csv).sort_values("threshold_months_nonmissing_2023")
    x = d["threshold_months_nonmissing_2023"]

    plt.figure(figsize=(8.5, 5.5))
    plt.errorbar(
        x,
        d["coef"],
        yerr=1.96 * d["se"],
        fmt="o-",
        color="black",
        capsize=3,
    )
    plt.axhline(0, color="gray", linewidth=1)
    # Ticks come from the thresholds actually in the table; a fixed
    # [0, 6, 9, 12] grid would mislabel or hide any other set of thresholds.
    positions, labels = _threshold_ticks(x)
    plt.xticks(positions, labels)
    plt.xlabel("Required non-missing months in 2023 (balancedness)")
    plt.ylabel("Coef on media freedom (robust SE, 95% CI)")
    plt.title("2020-2023 robustness: enforce 2023 coverage threshold\n(Outcome: transparency gap, per million)")
    plt.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=250)
    plt.close()


def _regplot_panel(ax, data: pd.DataFrame, title: str, ylabel: str) -> None:
    """Draw ``y`` vs ``media_freedom`` with a fitted line and 95% CI band on ``ax``."""
    sns.regplot(
        data=data.dropna(subset=["y", "media_freedom"]),
        x="media_freedom",
        y="y",
        scatter_kws={"alpha": 0.45, "s": 18},
        line_kws={"color": "black", "linewidth": 2},
        ci=95,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Information freedom (V-Dem, baseline <=2019)")
    ax.set_ylabel(ylabel)


def _coef_forest(ax, coefs: pd.DataFrame, labels: list[str], title: str) -> None:
    """Draw one horizontal coef +/- 1.96*se error bar per row of ``coefs`` on ``ax``."""
    ypos = np.arange(len(coefs))
    ax.errorbar(coefs["coef"], ypos, xerr=1.96 * coefs["se"], fmt="o", color="black", ecolor="black", capsize=2)
    ax.axvline(0, color="gray", linewidth=1)
    ax.set_yticks(ypos, labels)
    ax.set_title(title)
    ax.set_xlabel("Coef (95% CI)")


def _mark_missing(ax, message: str) -> None:
    """Hide ``ax``'s frame and show ``message`` centred in it (input table absent)."""
    ax.axis("off")
    ax.text(0.5, 0.5, message, ha="center", va="center")


def main() -> None:
    """Build the submission figures from the merged panel and result tables.

    Under ``<out-dir>/figures`` this writes the single-panel Figure 1 scatter,
    the 2023-threshold robustness plot (when its table exists), and the
    multi-panel ``fig1_panels.png`` / ``fig2_panels.png`` layouts. Panels whose
    input table is missing show a "not found" notice instead of failing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--panel", type=Path, default=Path("outputs/data/panel_merged.parquet"))
    parser.add_argument("--out-dir", type=Path, default=Path("submission"))
    parser.add_argument(
        "--tables-dir",
        type=Path,
        default=Path("outputs/tables"),
        help="Directory holding the coefficient / robustness CSV tables.",
    )
    args = parser.parse_args()

    out_dir = args.out_dir
    fig_dir = out_dir / "figures"
    fig_dir.mkdir(parents=True, exist_ok=True)
    tables = args.tables_dir
    gap_outcome = "gap_pm_excess_minus_reported"

    panel = pd.read_parquet(args.panel)

    w_main = Window("2020-2022", "2020-02", "2022-12")
    agg_gap = aggregate_country_window(panel, w_main, outcome=gap_outcome, min_months_nonmissing=12)
    main_scatter(
        agg_gap,
        fig_dir / "fig1_scatter_transparency_gap_2020_2022.png",
        y_label="Transparency gap (excess deaths - reported COVID deaths), per million\n(sum over 2020-2022)",
        title="Figure 1. Media freedom and transparency during COVID (2020-2022)",
    )

    # Guarded like the multi-panel figures below, so a missing table does not
    # abort the run before the remaining figures are written.
    thr_path = tables / "robustness_2023_completeness.csv"
    if thr_path.exists():
        thresholds_plot(
            thresholds_csv=thr_path,
            out_path=fig_dir / "fig2_robustness_2023_thresholds.png",
        )
    else:
        print(f"Skipping fig2_robustness_2023_thresholds.png: {thr_path} not found")

    # Figure 1 panels: (A) gap scatter, (B) reported scatter, (C) excess scatter, (D) monthly slope over time.
    agg_reported = aggregate_country_window(panel, w_main, outcome="reported_deaths_pm", min_months_nonmissing=12)
    agg_excess = aggregate_country_window(panel, w_main, outcome="excess_deaths_pm", min_months_nonmissing=12)
    # Align samples across panels using the gap sample countries.
    keep = set(agg_gap["iso_code"])
    agg_reported = agg_reported[agg_reported["iso_code"].isin(keep)]
    agg_excess = agg_excess[agg_excess["iso_code"].isin(keep)]

    plt.figure(figsize=(12, 9))
    _regplot_panel(
        plt.subplot(2, 2, 1),
        agg_gap,
        "Panel A. Transparency gap (2020-2022)",
        "Gap per million (excess - reported)",
    )
    _regplot_panel(
        plt.subplot(2, 2, 2),
        agg_reported,
        "Panel B. Reported COVID-19 deaths (2020-2022)",
        "Reported deaths per million (sum)",
    )
    _regplot_panel(
        plt.subplot(2, 2, 3),
        agg_excess,
        "Panel C. Excess deaths (2020-2022)",
        "Excess deaths per million (sum)",
    )

    ax4 = plt.subplot(2, 2, 4)
    coef_path = tables / "panel_gap_pm_excess_minus_reported_event_slope_coeffs.csv"
    if coef_path.exists():
        c = pd.read_csv(coef_path)
        c["month"] = pd.to_datetime(c["month"])
        ax4.plot(c["month"], c["coef"], color="black", linewidth=1.5)
        ax4.fill_between(c["month"], c["ci_lo"], c["ci_hi"], color="black", alpha=0.15)
        ax4.axhline(0, color="gray", linewidth=1)
        ax4.set_title("Panel D. Monthly association (adjusted)")
        ax4.set_xlabel("Month")
        ax4.set_ylabel("Gap per +1 info freedom")
    else:
        _mark_missing(ax4, "Monthly slope file not found")

    plt.tight_layout()
    plt.savefig(fig_dir / "fig1_panels.png", dpi=250)
    plt.close()

    # Figure 2 panels: robustness checks (alt-X, leave-one-continent-out, 2023 threshold).
    plt.figure(figsize=(13, 4.5))

    # Panel A: alternative X measures (2020-2022 gap, full controls).
    axa = plt.subplot(1, 3, 1)
    alt_path = tables / "robustness_alt_x.csv"
    if alt_path.exists():
        x_order = ["media_freedom", "freedom_expression", "alternative_info"]
        x_labels = {"media_freedom": "Composite", "freedom_expression": "Expression", "alternative_info": "Alt. info"}
        alt = pd.read_csv(alt_path)
        alt = alt[
            (alt["window"] == w_main.name)
            & (alt["outcome"] == gap_outcome)
            & alt["spec"].astype(str).str.contains("OLS_full")
            & alt["x_var"].isin(x_order)
        ].copy()
        alt["ord"] = alt["x_var"].map({k: i for i, k in enumerate(x_order)})
        alt = alt.sort_values("ord")
        _coef_forest(axa, alt, alt["x_var"].map(x_labels).tolist(), "Panel A. Alt. info measures")
    else:
        _mark_missing(axa, "robustness_alt_x.csv not found")

    # Panel B: leave-one-continent-out
    axb = plt.subplot(1, 3, 2)
    loo_path = tables / "robustness_leave_one_continent.csv"
    if loo_path.exists():
        loo = pd.read_csv(loo_path)
        loo = loo[(loo["window"] == w_main.name) & (loo["outcome"] == gap_outcome)]
        loo = loo.sort_values("excluded_continent")
        _coef_forest(axb, loo, loo["excluded_continent"].tolist(), "Panel B. Leave-one-continent-out")
    else:
        _mark_missing(axb, "robustness_leave_one_continent.csv not found")

    # Panel C: 2023 completeness threshold
    axc = plt.subplot(1, 3, 3)
    if thr_path.exists():
        thr = pd.read_csv(thr_path).sort_values("threshold_months_nonmissing_2023")
        axc.errorbar(
            thr["threshold_months_nonmissing_2023"],
            thr["coef"],
            yerr=1.96 * thr["se"],
            fmt="o-",
            color="black",
            capsize=3,
        )
        axc.axhline(0, color="gray", linewidth=1)
        # Tick at the thresholds present in the table rather than a fixed grid.
        ticks = sorted(thr["threshold_months_nonmissing_2023"].dropna().unique().tolist())
        axc.set_xticks(ticks, [f">={t:g}" for t in ticks])
        axc.set_title("Panel C. 2023 coverage threshold")
        axc.set_xlabel("Required non-missing months in 2023")
        axc.set_ylabel("Coef (95% CI)")
    else:
        _mark_missing(axc, "robustness_2023_completeness.csv not found")

    plt.tight_layout()
    plt.savefig(fig_dir / "fig2_panels.png", dpi=250)
    plt.close()

    print(f"Wrote figures to {out_dir / 'figures'}")


if __name__ == "__main__":
    # Build the figures only when run as a script, not when imported.
    main()
